﻿#pragma kernel UpdateRopeMesh

#include "PathFrame.cginc"

// Per-chunk path data, read through the `pathData` buffer (indexed by chunk).
// NOTE(review): the field order/types presumably have to mirror the struct used
// to fill the buffer on the CPU side - confirm before reordering anything.
struct smootherPathData
{
    uint smoothing;               // not read by UpdateRopeMesh
    float decimation;             // not read by UpdateRopeMesh
    float twist;                  // not read by UpdateRopeMesh
    float restLength;             // read from a rope's first chunk; anchors and scales the v texture coordinate
    float smoothLength;           // summed over all chunks of a rope to get its total smoothed length
    uint usesOrientedParticles;   // not read by UpdateRopeMesh
};

// Per-renderer extrusion settings, read through the `extrudedData` buffer
// (indexed by renderer).
// NOTE(review): the field order/types presumably have to mirror the struct used
// to fill the buffer on the CPU side - confirm before reordering anything.
struct extrudedMeshData
{
    int sectionVertexCount;   // not read by UpdateRopeMesh
    float thicknessScale;     // multiplies each frame's thickness to get the section thickness
    float uvAnchor;           // scales the initial (negative) v offset together with uvScale.y and the rest length
    uint normalizeV;          // non-zero: v advances by distance / total smooth length; zero: by distance / (smooth length / rest length)
    float2 uvScale;           // x scales the u coordinate around the section, y scales the v coordinate along the rope
};

// Per-renderer index (s) into chunkOffsets.
StructuredBuffer<int> pathSmootherIndices;
// Chunks [chunkOffsets[s], chunkOffsets[s + 1]) belong to path smoother s.
StructuredBuffer<int> chunkOffsets;

// Path frames of all chunks; chunk i owns frameCounts[i] consecutive frames
// starting at frameOffsets[i].
StructuredBuffer<pathFrame> frames;
StructuredBuffer<int> frameOffsets;
StructuredBuffer<int> frameCounts;

// 2D cross-section vertices. Renderer k uses section sectionIndices[k], whose
// vertices span [sectionOffsets[n], sectionOffsets[n + 1]) in sectionData.
StructuredBuffer<float2> sectionData;
StructuredBuffer<int> sectionOffsets;
StructuredBuffer<int> sectionIndices;

// Per-renderer first vertex, first triangle and triangle count in the output buffers.
StructuredBuffer<int> vertexOffsets;
StructuredBuffer<int> triangleOffsets;
StructuredBuffer<int> triangleCounts;

// Per-renderer extrusion settings and per-chunk path data.
StructuredBuffer<extrudedMeshData> extrudedData;
StructuredBuffer<smootherPathData> pathData;

// Output vertex data: 16 32-bit words per vertex, written by UpdateRopeMesh as
// position (3), normal (3), tangent (4), color (4), uv (2).
RWByteAddressBuffer vertices;
// Output index data: three 32-bit vertex indices per triangle.
RWByteAddressBuffer tris;

// Variables set from the CPU
// firstRenderer is added to the thread index to get the renderer index;
// threads with index >= rendererCount exit immediately.
uint firstRenderer;
uint rendererCount;

// Re-orients `frame` so that its tangent points from its own position towards
// `target`'s position, carrying normal and binormal along with the same rotation.
//   frame  - frame to re-orient (passed by value; the modified copy is returned).
//   target - frame whose position is looked at.
//   dist   - receives the distance between both frame positions.
pathFrame LookAt(pathFrame frame, in pathFrame target, out float dist)
{
    float3 toTarget = target.position - frame.position;
    dist = length(toTarget);

    // EPSILON keeps the division finite when both frames share the same position.
    float3 newTangent = toTarget / (dist + EPSILON);

    // apply the rotation that takes the old tangent to the new one to the rest of the basis:
    quaternion alignment = from_to_rotation(frame.tangent, newTangent);
    frame.normal = rotate_vector(alignment, frame.normal);
    frame.binormal = rotate_vector(alignment, frame.binormal);
    frame.tangent = newTangent;

    return frame;
}

// Builds the extruded mesh of one rope renderer per thread.
//
// Thread u handles renderer k = firstRenderer + u. For every frame of every
// chunk of that renderer's path, one ring of section vertices is written to
// `vertices`, and consecutive rings within the same chunk are stitched with
// two triangles per section segment, written to `tris`.
//
// Vertex layout (16 32-bit words per vertex):
//   words 0-2   position
//   words 3-5   normal (section offset scaled by thickness, not normalized here)
//   words 6-9   tangent (xyz = cross(normal, frame tangent), w = -1)
//   words 10-13 frame color
//   words 14-15 uv
[numthreads(128, 1, 1)]
void UpdateRopeMesh (uint3 id : SV_DispatchThreadID) 
{
    uint u = id.x;
    if (u >= rendererCount) return;

    int k = firstRenderer + u;
    int s = pathSmootherIndices[k];

    // per-renderer settings do not change inside the loops below, so load them once:
    extrudedMeshData extruded = extrudedData[k];

    int sectionStart = sectionOffsets[sectionIndices[k]];
    int sectionSegments = (sectionOffsets[sectionIndices[k] + 1] - sectionStart) - 1;
    int verticesPerSection = sectionSegments + 1;   // the last vertex in each section must be duplicated, due to uv wraparound.

    // a section made of a single vertex has zero segments: clamp the divisor so
    // its u coordinate is 0 instead of the NaN produced by 0/0.
    float uDivisor = (float)max(sectionSegments, 1);

    int firstChunk = chunkOffsets[s];
    int endChunk = chunkOffsets[s + 1];
    float restLength = pathData[firstChunk].restLength;

    // total smoothed length of the rope, over all its chunks:
    float smoothLength = 0;
    int i;
    for (i = firstChunk; i < endChunk; ++i)
        smoothLength += pathData[i].smoothLength;

    float vCoord = -extruded.uvScale.y * restLength * extruded.uvAnchor;
    float actualToRestLengthRatio = smoothLength / restLength;

    // distance travelled along the rope is divided by this to advance v:
    float vDivisor = extruded.normalizeV ? smoothLength : actualToRestLengthRatio;

    int firstVertex = vertexOffsets[k];
    int firstTriangle = triangleOffsets[k];

    // clear out triangle indices for this rope:
    for (i = firstTriangle; i < firstTriangle + triangleCounts[k]; ++i)
        tris.Store3((i * 3) << 2, uint3(0, 0, 0));

    int triBase = firstTriangle * 3;   // first index slot owned by this rope
    int tri = 0;                       // index slots written so far
    int sectionIndex = 0;              // rings written so far, across all chunks

    // for each chunk in the rope:
    for (i = firstChunk; i < endChunk; ++i)
    {
        int firstFrame = frameOffsets[i];
        int frameCount = frameCounts[i];

        for (int f = 0; f < frameCount; ++f)
        {
            // the first frame of a chunk uses itself as previous frame (zero distance):
            pathFrame prevFrame = frames[firstFrame + max(f - 1, 0)];
            pathFrame frame = frames[firstFrame + f];

            // advance v texcoord:
            vCoord += extruded.uvScale.y * (distance(frame.position, prevFrame.position) / vDivisor);

            // calculate section thickness; the section offset is scaled by it:
            float sectionThickness = frame.thickness * extruded.thicknessScale;

            // rings are only stitched to the next ring of the same chunk:
            bool stitchToNext = f < frameCount - 1;

            int ringStart = firstVertex + sectionIndex * verticesPerSection;
            int nextRingStart = firstVertex + (sectionIndex + 1) * verticesPerSection;

            // Loop around each segment:
            for (int j = 0; j <= sectionSegments; ++j)
            {
                float2 sectionVertex = sectionData[sectionStart + j];

                // calculate normal using section vertex, curve normal and binormal:
                float3 normal = (sectionVertex.x * frame.normal.xyz + sectionVertex.y * frame.binormal.xyz) * sectionThickness;

                // offset curve position by normal:
                float3 vertex = frame.position.xyz + normal;

                // cross(normal, curve tangent)
                float4 texTangent = float4(cross(normal, frame.tangent), -1);

                float2 uv = float2(j / uDivisor * extruded.uvScale.x, vCoord);

                int current = ringStart + j;
                int base = current * 16;
                vertices.Store3( base<<2, asuint(vertex));
                vertices.Store3((base + 3)<<2, asuint(normal));
                vertices.Store4((base + 6)<<2, asuint(texTangent));
                vertices.Store4((base + 10)<<2, asuint(frame.color));
                vertices.Store2((base + 14)<<2, asuint(uv));

                if (j < sectionSegments && stitchToNext)
                {
                    int next = nextRingStart + j;

                    // two triangles covering the quad between this ring and the next one:
                    tris.Store3((triBase + tri) << 2,     asuint(int3(current,     next, current + 1)));
                    tris.Store3((triBase + tri + 3) << 2, asuint(int3(current + 1, next, next + 1)));
                    tri += 6;
                }
            }
            sectionIndex++;
        }
    }
}